{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# <i>CatBoost learning to rank on Microsoft dataset</i>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from catboost import CatBoost, Pool, MetricVisualizer\n",
    "from copy import deepcopy\n",
    "import numpy as np\n",
    "import os\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ranking problem\n",
    "\n",
    "Let's introduce some notation:\n",
    "* $Q = \\{q_1, \\dots, q_n\\}$ is the set of groups\n",
    "* $D_q = \\{d_{q1}, \\dots, d_{qm}\\}$ is the set of objects retrieved for a group $q$\n",
    "* $L_q = \\{l_{q1}, \\dots, l_{qm}\\}$ are the relevance labels of the objects from the set $D_q$\n",
    "\n",
    "Every object $d_{qi}$ is represented by a feature vector describing the association between the group and the object.\n",
    "\n",
    "So every group is associated with a set of objects. For example, when ranking documents for a search query, the group is the query and the objects are the documents.\n",
    "\n",
    "The goal is to learn a ranking function $f = f(d_{qi})$ such that, for every group in $Q$, the ranking of the objects $d_{qi}$ by their scores $x_{qi} = f(d_{qi})$ is as close as possible to the ideal ranking given by the editorial judgments $l_{qi}$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ranking quality metrics:\n",
    "* __Precision__\n",
    "    $$ \\mbox{P}=\\frac{|\\{\\mbox{relevant docs}\\}\\cap\\{\\mbox{retrieved docs}\\}|}{|\\{\\mbox{retrieved docs}\\}|} $$\n",
    "* __Recall__\n",
    "    $$ \\mbox{R}=\\frac{|\\{\\mbox{relevant docs}\\}\\cap\\{\\mbox{retrieved docs}\\}|}{|\\{\\mbox{relevant docs}\\}|} $$\n",
    "    \n",
    "    Notation $@k$ means that the metric is computed on the first $k$ documents of the ranked list.\n",
    "\n",
    "    For example, if the relevant documents are at ranks 1, 2, 5, 7 and 9 (positions are numbered from 1) among ten retrieved documents, then $P@5 = \\frac{3}{5}$.\n",
    "\n",
    "* __Mean average precision (MAP)__\n",
    "    $$\\frac{1}{|Q|}\\sum_{q \\in Q} \\frac{1}{|\\mbox{relevant docs in } D_q|} \\sum_{k} P@k(q) \\times rel(q, k) $$\n",
    "    \n",
    "    Here $rel(q, k)$ is the relevance label of the document at the $k$-th position in our ranking of $D_q$. The metric computes the average precision for each query, weighting the precision at each position by the relevance of the document there, and then averages it over all queries.\n",
    "    \n",
    "* __Discounted cumulative gain (DCG)__\n",
    "    $$\\sum_{k=1}^{m} \\frac{2^{l_{qk}} - 1}{\\log_2(k+1)}$$\n",
    "    \n",
    "    Here $m$ is the number of documents in $D_q$ and $l_{qk}$ is the label of the document placed at the $k$-th position. This metric models user behavior: attention is highest at the top of the list and decreases nonlinearly towards the end.\n",
    "    \n",
    "* __NDCG__ - normalized DCG = DCG $~ / ~$ IDCG, where IDCG is the maximum possible value of DCG for the given set of relevance labels.\n",
    "\n",
    "* __AverageGain__ - the average label value of the top-ranked objects (the number of objects is set by the `top` parameter).\n",
    "\n",
    "* __[PFound](https://tech.yandex.com/catboost/doc/dg/references/pfound-docpage/#pfound)__\n",
    "    \n",
    "More on wiki: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)\n",
    "\n",
    "The $k$ in $@k$ is set for every metric through the metric parameter `top`; for example, \"NDCG:top=10\" means NDCG@10."
   ]
  },
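  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The definitions above can be checked on the example from the precision paragraph with a small NumPy sketch. It is an illustration under stated assumptions, not CatBoost's implementation: the helper names are hypothetical, DCG uses the common exponential gain $2^{l} - 1$, and CatBoost's own metrics may use a different gain or tie handling, so values need not match the plots exactly.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical helpers: inputs are listed in ranked order (position 1 first)\n",
    "def precision_at_k(is_relevant, k):\n",
    "    # share of relevant documents among the first k\n",
    "    return np.sum(is_relevant[:k]) / k\n",
    "\n",
    "def recall_at_k(is_relevant, k):\n",
    "    # share of all relevant documents that appear among the first k\n",
    "    return np.sum(is_relevant[:k]) / np.sum(is_relevant)\n",
    "\n",
    "def dcg_at_k(labels, k):\n",
    "    # exponential gain 2^l - 1, discounted by log2(position + 1)\n",
    "    labels = np.asarray(labels, dtype=float)[:k]\n",
    "    positions = np.arange(1, labels.shape[0] + 1)\n",
    "    return np.sum((2 ** labels - 1) / np.log2(positions + 1))\n",
    "\n",
    "def ndcg_at_k(labels, k):\n",
    "    # IDCG is the DCG of the same labels sorted from best to worst\n",
    "    ideal = dcg_at_k(np.sort(labels)[::-1], k)\n",
    "    return dcg_at_k(labels, k) / ideal if ideal > 0 else 0.0\n",
    "\n",
    "# Relevant documents at ranks 1, 2, 5, 7, 9 out of ten retrieved\n",
    "is_relevant = np.zeros(10)\n",
    "is_relevant[[0, 1, 4, 6, 8]] = 1\n",
    "print(precision_at_k(is_relevant, 5))  # 0.6, i.e. 3/5\n",
    "print(recall_at_k(is_relevant, 5))     # 0.6: 3 of the 5 relevant documents\n",
    "print(ndcg_at_k([2, 3, 0, 1], 4))      # below 1, the ideal order is 3, 2, 1, 0\n",
    "```"
   ]
  },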
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download part of [MSRank](https://www.microsoft.com/en-us/research/project/mslr/) dataset from CatBoost datasets storage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from catboost.datasets import msrank_10k\n",
    "train_df, test_df = msrank_10k()\n",
    "\n",
    "X_train = train_df.drop([0, 1], axis=1).values\n",
    "y_train = train_df[0].values\n",
    "queries_train = train_df[1].values\n",
    "\n",
    "X_test = test_df.drop([0, 1], axis=1).values\n",
    "y_test = test_df[0].values\n",
    "queries_test = test_df[1].values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__Number of documents__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "num_documents = X_train.shape[0]\n",
    "print(num_documents)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__Number of features__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "X_train.shape[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__Relevance labels statistics__\n",
    "\n",
    "Labels range from 0 (irrelevant) to the maximum value (highly relevant). The output shows the number of documents for each label value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "Counter(y_train).items()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To compute metrics such as NDCG and PFound, relevance labels should lie in the segment \\[0,1\\], so we divide them by the maximum label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "max_relevance = np.max(y_train)\n",
    "# Out-of-place division also works when the labels are stored as integers\n",
    "y_train = y_train / max_relevance\n",
    "y_test = y_test / max_relevance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__Number of queries__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "num_queries = np.unique(queries_train).shape[0]\n",
    "num_queries"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creation of CatBoost pools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train = Pool(\n",
    "    data=X_train,\n",
    "    label=y_train,\n",
    "    group_id=queries_train\n",
    ")\n",
    "\n",
    "test = Pool(\n",
    "    data=X_test,\n",
    "    label=y_test,\n",
    "    group_id=queries_test\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### You can also create pools from files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data_dir = './msrank'\n",
    "\n",
    "if not os.path.exists(data_dir):\n",
    "    os.makedirs(data_dir)\n",
    "\n",
    "train_file = os.path.join(data_dir, 'train.csv')\n",
    "test_file = os.path.join(data_dir, 'test.csv')\n",
    "\n",
    "train_df.to_csv(train_file, index=False, header=False)\n",
    "test_df.to_csv(test_file, index=False, header=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "description_file = os.path.join(data_dir, 'dataset.cd')\n",
    "with open(description_file, 'w') as f:\n",
    "    f.write('0\\tLabel\\n')\n",
    "    f.write('1\\tQueryId\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "Pool(data=train_file, column_description=description_file, delimiter=',')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### <span style=\"color:#ce2029\">Attention:</span> all objects in dataset must be grouped by group_id\n",
    "\n",
    "For example, if the dataset consists of five documents\n",
    "\\[d1, d2, d3, d4, d5\\] with corresponding queries \\[q1, q2, q2, q1, q2\\], then the dataset should look like:\n",
    "\n",
    "$$\\begin{pmatrix}\n",
    "    d_1, q_1, f_1\\\\\n",
    "    d_4, q_1, f_4\\\\\n",
    "    d_2, q_2, f_2\\\\\n",
    "    d_3, q_2, f_3\\\\\n",
    "    d_5, q_2, f_5\\\\\n",
    "\\end{pmatrix} \\hspace{6px} \\texttt{or} \\hspace{6px}\n",
    "\\begin{pmatrix}\n",
    "    d_2, q_2, f_2\\\\\n",
    "    d_3, q_2, f_3\\\\\n",
    "    d_5, q_2, f_5\\\\\n",
    "    d_1, q_1, f_1\\\\\n",
    "    d_4, q_1, f_4\\\\\n",
    "\\end{pmatrix}$$\n",
    "\n",
    "where $f_i$ is the feature vector of the $i$-th document."
   ]
  },
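  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If your data are not grouped yet, one way to satisfy this requirement is a stable sort by group id: it puts the rows of each group next to each other and keeps the original order of documents inside a group. A minimal NumPy sketch on the five-document example above (the helper name `group_contiguously` is hypothetical, and queries are encoded as integers for simplicity):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def group_contiguously(X, y, group_ids):\n",
    "    # A stable sort by group id makes every group contiguous and keeps\n",
    "    # the original order of documents inside each group\n",
    "    order = np.argsort(group_ids, kind='stable')\n",
    "    return X[order], y[order], group_ids[order]\n",
    "\n",
    "# Documents d1..d5 with queries q1, q2, q2, q1, q2 (encoded as 1 and 2)\n",
    "X = np.arange(10).reshape(5, 2)  # row i is the feature vector of document i + 1\n",
    "y = np.array([0.0, 1.0, 0.5, 0.25, 0.0])\n",
    "queries = np.array([1, 2, 2, 1, 2])\n",
    "\n",
    "X_sorted, y_sorted, queries_sorted = group_contiguously(X, y, queries)\n",
    "print(queries_sorted)  # [1 1 2 2 2], i.e. documents d1, d4, d2, d3, d5\n",
    "```\n",
    "\n",
    "With the notebook's arrays the call would be `group_contiguously(X_train, y_train, queries_train)`. The order of the groups themselves does not matter, as the two matrices above show; ascending group ids are just a convenient choice."
   ]
  },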
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reducing problem to machine learning task\n",
    "\n",
    "The first and simplest idea is to predict the document relevance $l_{qk}$ directly by minimizing RMSE, where $N$ is the total number of documents:\n",
    "\n",
    "$$\\sqrt{ \\frac{1}{N} \\sum_q \\sum_{d_{qk}} \\left(f(d_{qk}) - l_{qk} \\right)^2 }$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "default_parameters = {\n",
    "    'iterations': 2000,\n",
    "    'custom_metric': ['NDCG', 'PFound', 'AverageGain:top=10'],\n",
    "    'verbose': False,\n",
    "    'random_seed': 0,\n",
    "}\n",
    "\n",
    "parameters = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def fit_model(loss_function, additional_params=None, train_pool=train, test_pool=test):\n",
    "    parameters = deepcopy(default_parameters)\n",
    "    parameters['loss_function'] = loss_function\n",
    "    parameters['train_dir'] = loss_function\n",
    "    \n",
    "    if additional_params is not None:\n",
    "        parameters.update(additional_params)\n",
    "        \n",
    "    model = CatBoost(parameters)\n",
    "    model.fit(train_pool, eval_set=test_pool, plot=True)\n",
    "    \n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's train the simplest model and also show the precision, recall and MAP metrics from the introduction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "model = fit_model('RMSE', {'custom_metric': ['PrecisionAt:top=10', 'RecallAt:top=10', 'MAP:top=10']})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Group weights parameter\n",
    "Suppose we know that some queries are more important to us than others.<br/>\n",
    "\"Importance\" here means how accurate or good CatBoost's predictions should be for the given queries.<br/>\n",
    "You can pass this additional information to the learner with the ``group_weight`` parameter of `Pool`.<br/>\n",
    "Under the hood, CatBoost uses these weights in the loss function by simply multiplying each group's summand by its weight.<br/>\n",
    "So the bigger the weight $\\rightarrow$ the more attention the query gets.<br/>\n",
    "Let's show an example of the training procedure with random query weights."
   ]
  },
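  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect of group weights can be illustrated with a simplified loss in which each group contributes its own squared-error summand, multiplied by the group's weight. This is a sketch of the idea described above, not CatBoost's internal code: the helper name and the dictionary of weights are hypothetical, and any normalization CatBoost applies is omitted.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def group_weighted_squared_error(predictions, labels, group_ids, group_weights):\n",
    "    # Hypothetical helper: sum of squared errors per group, each group's\n",
    "    # summand multiplied by its weight from the group_weights dictionary\n",
    "    total = 0.0\n",
    "    for group in np.unique(group_ids):\n",
    "        mask = group_ids == group\n",
    "        group_term = np.sum((predictions[mask] - labels[mask]) ** 2)\n",
    "        total += group_weights[group] * group_term\n",
    "    return total\n",
    "\n",
    "predictions = np.array([0.5, 0.2, 0.9, 0.1])\n",
    "labels = np.array([1.0, 0.0, 1.0, 0.0])\n",
    "group_ids = np.array([1, 1, 2, 2])\n",
    "\n",
    "equal = group_weighted_squared_error(predictions, labels, group_ids, {1: 1.0, 2: 1.0})\n",
    "doubled = group_weighted_squared_error(predictions, labels, group_ids, {1: 2.0, 2: 1.0})\n",
    "print(equal, doubled)  # group 1 contributes about 0.29 and group 2 about 0.02 before weighting\n",
    "```\n",
    "\n",
    "Doubling the weight of group 1 doubles its share of the loss, so the learner gains more by reducing errors in that group."
   ]
  },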
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def create_weights(queries):\n",
    "    query_set = np.unique(queries)\n",
    "    query_weights = np.random.uniform(size=query_set.shape[0])\n",
    "    weights = np.zeros(shape=queries.shape)\n",
    "    \n",
    "    for i, query_id in enumerate(query_set):\n",
    "        weights[queries == query_id] = query_weights[i]\n",
    "    \n",
    "    return weights\n",
    "    \n",
    "\n",
    "train_with_weights = Pool(\n",
    "    data=X_train,\n",
    "    label=y_train,\n",
    "    group_weight=create_weights(queries_train),\n",
    "    group_id=queries_train\n",
    ")\n",
    "\n",
    "test_with_weights = Pool(\n",
    "    data=X_test,\n",
    "    label=y_test,\n",
    "    group_weight=create_weights(queries_test),\n",
    "    group_id=queries_test\n",
    ")\n",
    "\n",
    "fit_model(\n",
    "    'RMSE', \n",
    "    additional_params={'train_dir': 'RMSE_weights'}, \n",
    "    train_pool=train_with_weights,\n",
    "    test_pool=test_with_weights\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A special case: top-1 prediction\n",
    "\n",
    "Sometimes you need to predict only the single most relevant object for a given query.<br/>\n",
    "For this purpose CatBoost has a mode called __QuerySoftMax__.\n",
    "\n",
    "Suppose our dataset contains a binary target: 1 $-$ the best document for a query, 0 $-$ all others.<br/>\n",
    "We will maximize the probability of being the best document for a given query.<br/>\n",
    "The MSRank dataset doesn't contain binary labels, but to demonstrate __QuerySoftMax__ we convert it to that format,<br/> choosing the best document for every query."
   ]
  },
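  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea of maximizing the probability of the best document can be sketched with a per-query softmax: the scores of the documents of one query are turned into probabilities $\\exp(f(d)) / \\sum_{d'} \\exp(f(d'))$, and the loss is the negative log-probability of the document labelled 1. This is a minimal NumPy illustration with a hypothetical helper name; CatBoost's exact QuerySoftMax formula (weights, regularization) may differ in details.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def query_softmax_loss(scores, labels, group_ids):\n",
    "    # Hypothetical helper: softmax over the scores of each group, then the\n",
    "    # negative log-probability of the documents labelled 1 (the best ones)\n",
    "    total = 0.0\n",
    "    for group in np.unique(group_ids):\n",
    "        mask = group_ids == group\n",
    "        s = scores[mask]\n",
    "        # log-softmax; subtracting the maximum keeps exp from overflowing\n",
    "        log_probs = s - s.max() - np.log(np.sum(np.exp(s - s.max())))\n",
    "        total -= np.sum(labels[mask] * log_probs)\n",
    "    return total\n",
    "\n",
    "toy_groups = np.array([1, 1, 1, 2, 2])\n",
    "toy_labels = np.array([0.0, 1.0, 0.0, 1.0, 0.0])  # one best document per query\n",
    "\n",
    "good_loss = query_softmax_loss(np.array([0.0, 3.0, 0.0, 2.0, -1.0]), toy_labels, toy_groups)\n",
    "bad_loss = query_softmax_loss(np.array([3.0, 0.0, 0.0, -1.0, 2.0]), toy_labels, toy_groups)\n",
    "print(good_loss < bad_loss)  # True: scoring the best documents highest lowers the loss\n",
    "```"
   ]
  },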
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def get_best_documents(labels, queries):\n",
    "    query_set = np.unique(queries)\n",
    "    num_queries = query_set.shape[0]\n",
    "    by_query_arg_max = {query: -1 for query in query_set}\n",
    "    \n",
    "    for i, query in enumerate(queries):\n",
    "        best_idx = by_query_arg_max[query]\n",
    "        if best_idx == -1 or labels[best_idx] < labels[i]:\n",
    "            by_query_arg_max[query] = i\n",
    "    \n",
    "    binary_best_docs = np.zeros(shape=labels.shape)\n",
    "    for arg_max in by_query_arg_max.values():\n",
    "        binary_best_docs[arg_max] = 1.\n",
    "        \n",
    "    return binary_best_docs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "best_docs_train = get_best_documents(y_train, queries_train)\n",
    "best_docs_test = get_best_documents(y_test, queries_test)\n",
    "\n",
    "train_with_weights = Pool(\n",
    "    data=X_train,\n",
    "    label=best_docs_train,\n",
    "    group_id=queries_train,\n",
    "    group_weight=create_weights(queries_train)\n",
    ")\n",
    "\n",
    "test_with_weights = Pool(\n",
    "    data=X_test,\n",
    "    label=best_docs_test,\n",
    "    group_id=queries_test,\n",
    "    group_weight=create_weights(queries_test)\n",
    ")\n",
    "\n",
    "fit_model(\n",
    "    'QuerySoftMax',\n",
    "    additional_params={'custom_metric': 'AverageGain:top=1'},\n",
    "    train_pool=train_with_weights,\n",
    "    test_pool=test_with_weights\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reducing problem, step 2\n",
    "\n",
    "Now let's look at an example of document relevance labels:\n",
    "\n",
    "$$ \n",
    "    \\begin{align}\n",
    "    labels(q_1) &= \\begin{bmatrix}\n",
    "           4 \\\\\n",
    "           3 \\\\\n",
    "           3 \\\\\n",
    "           1\n",
    "         \\end{bmatrix},\n",
    "    labels(q_2) &= \\begin{bmatrix}\n",
    "           2 \\\\\n",
    "           1 \\\\\n",
    "           1 \\\\\n",
    "           0\n",
    "         \\end{bmatrix}\n",
    "   \\end{align}\n",
    "$$\n",
    "\n",
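    "This means that with the RMSE loss function we pay more attention to q1 than to q2: the model has to reproduce the absolute level of the labels of each query, although for ranking only the order of documents within a query matters.\n",
    "\n",
    "To avoid this problem we introduce into RMSE a coefficient $c_q$ that depends only on the query (in fact, its optimal value $c_q = \\frac{1}{m}\\sum_{k}\\left(f(d_{qk}) - l_{qk}\\right)$ is the mean difference between prediction and label over the documents of $q$):\n",
    "\n",
    "$$\\sqrt{ \\frac{1}{N} \\sum_q \\sum_{d_{qk}} \\left(f(d_{qk}) - l_{qk} - \\color{red}{c_{q}} \\right)^2 }$$"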
   ]
  },
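  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal NumPy sketch of the per-query offset from the formula above (the helper name is hypothetical). It uses the labels of $q_1$ and $q_2$ from the example and predictions for $q_1$ that are shifted by a constant: plain RMSE penalizes the shift, while subtracting the optimal $c_q$ (the mean residual of the query) removes it.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def query_rmse(predictions, labels, group_ids):\n",
    "    # Hypothetical helper: RMSE after subtracting a per-query offset c_q.\n",
    "    # The c_q minimizing a query's squared error is its mean residual.\n",
    "    residuals = predictions - labels\n",
    "    total = 0.0\n",
    "    for group in np.unique(group_ids):\n",
    "        mask = group_ids == group\n",
    "        c_q = residuals[mask].mean()\n",
    "        total += np.sum((residuals[mask] - c_q) ** 2)\n",
    "    # N is the total number of documents\n",
    "    return np.sqrt(total / labels.shape[0])\n",
    "\n",
    "toy_labels = np.array([4.0, 3.0, 3.0, 1.0, 2.0, 1.0, 1.0, 0.0])  # q1 and q2 from the text\n",
    "toy_groups = np.array([1, 1, 1, 1, 2, 2, 2, 2])\n",
    "# q1 is predicted with a constant shift of -2, q2 exactly\n",
    "predictions = np.array([2.0, 1.0, 1.0, -1.0, 2.0, 1.0, 1.0, 0.0])\n",
    "\n",
    "plain_rmse = np.sqrt(np.mean((predictions - toy_labels) ** 2))\n",
    "print(plain_rmse)                                     # sqrt(2), about 1.414\n",
    "print(query_rmse(predictions, toy_labels, toy_groups))  # 0.0, the shift is absorbed by c_q\n",
    "```"
   ]
  },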
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "fit_model('QueryRMSE')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reducing problem, step 3\n",
    "\n",
    "Since the goal of ranking is to predict an ordered list of documents (which can be derived from the predicted relevances), the RMSE loss function doesn't take into account the relations between documents: the first is better than the second, the second is better than the third and the fifth, etc.\n",
    "\n",
    "We can easily bring this information into the loss function by reducing the problem not to regression but to classification of pairs of documents $(d_i, d_j)$: is the $i$-th document better than the $j$-th or not.\n",
    "\n",
    "So we minimize the negative log-likelihood over the pairs $(d_i, d_j)$ in which $d_i$ is more relevant than $d_j$:\n",
    "\n",
    "$$ - \\sum_{i,j \\in Pairs} \\log \\left( \\frac{1}{1 + \\exp\\left(-(f(d_i) - f(d_j))\\right)} \\right) $$\n",
    "\n",
    "Methods based on pair comparisons are called __pairwise__; in CatBoost this objective is called __PairLogit__.\n",
    "\n",
    "There's no need to change the dataset: CatBoost generates the pairs for us. The maximum number of generated pairs per group is controlled by the `max_pairs` parameter of the loss function."
   ]
  },
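  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The pairwise objective can be sketched in NumPy: pairs are generated inside each group from documents with different labels, and each pair adds $\\log(1 + e^{-(f(d_i) - f(d_j))})$, which equals the term in the formula above. This is an illustration with hypothetical helper names; CatBoost generates and samples pairs internally, and its procedure may differ.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def generate_pairs(labels, group_ids):\n",
    "    # Hypothetical helper: all (winner, loser) pairs inside each group;\n",
    "    # documents with equal labels form no pair\n",
    "    pairs = []\n",
    "    for group in np.unique(group_ids):\n",
    "        idx = np.flatnonzero(group_ids == group)\n",
    "        for a in idx:\n",
    "            for b in idx:\n",
    "                if labels[a] > labels[b]:\n",
    "                    pairs.append((a, b))\n",
    "    return np.array(pairs, dtype=int).reshape(-1, 2)\n",
    "\n",
    "def pair_logit_loss(scores, pairs):\n",
    "    # -log(1 / (1 + exp(-diff))) equals log(1 + exp(-diff));\n",
    "    # np.logaddexp(0, -diff) computes it without overflow\n",
    "    diff = scores[pairs[:, 0]] - scores[pairs[:, 1]]\n",
    "    return np.sum(np.logaddexp(0.0, -diff))\n",
    "\n",
    "toy_labels = np.array([2.0, 1.0, 0.0, 1.0, 0.0])\n",
    "toy_groups = np.array([1, 1, 1, 2, 2])\n",
    "toy_pairs = generate_pairs(toy_labels, toy_groups)\n",
    "print(toy_pairs.tolist())  # [[0, 1], [0, 2], [1, 2], [3, 4]]\n",
    "\n",
    "correct_order = pair_logit_loss(np.array([3.0, 2.0, 1.0, 1.0, 0.0]), toy_pairs)\n",
    "reversed_order = pair_logit_loss(np.array([1.0, 2.0, 3.0, 0.0, 1.0]), toy_pairs)\n",
    "print(correct_order < reversed_order)  # True\n",
    "```"
   ]
  },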
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "fit_model('PairLogit')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also specify the pairs directly. There are two ways to do that:\n",
    "\n",
    "1. A two-dimensional matrix with shape=(num_pairs, 2) $\\rightarrow$ (winner_id, loser_id): list, numpy.array, pandas.DataFrame.\n",
    "2. A path to the input file that contains the pair descriptions:\n",
    "    * Row format: $\\texttt{[winner index, loser index, pair weight]}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def read_groups(file_name):\n",
    "    groups = {}\n",
    "    group_ids = []\n",
    "\n",
    "    with open(file_name) as f:\n",
    "        for doc_id, line in enumerate(f):\n",
    "            line = line.split(',')[:2]\n",
    "            \n",
    "            label, query_id = tuple(map(float, line))\n",
    "            if query_id not in groups:\n",
    "                groups[query_id] = []\n",
    "            groups[query_id].append((doc_id, label))\n",
    "\n",
    "            group_ids.append(query_id)\n",
    "\n",
    "    return groups, group_ids\n",
    "            \n",
    "train_groups, train_group_ids = read_groups(train_file)\n",
    "assert num_queries == len(train_groups)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "pairs = []\n",
    "\n",
    "for group in train_groups.values():\n",
    "    for i in range(len(group)):\n",
    "        for j in range(i + 1, len(group)):\n",
    "            doc_i, relevance_i = group[i]\n",
    "            doc_j, relevance_j = group[j]\n",
    "            # (winner, loser); documents with equal relevance form no pair\n",
    "            if relevance_i > relevance_j:\n",
    "                pairs.append((doc_i, doc_j))\n",
    "            elif relevance_i < relevance_j:\n",
    "                pairs.append((doc_j, doc_i))\n",
    "\n",
    "pairs_file = os.path.join(data_dir, 'pairs.csv')\n",
    "\n",
    "with open(pairs_file, 'w') as f:\n",
    "    for winner, loser in pairs:\n",
    "        # row format: winner index, loser index, pair weight (tab-separated)\n",
    "        f.write(str(winner) + '\\t' + str(loser) + '\\t1\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "pool1 = Pool(data=X_train, label=y_train, group_id=train_group_ids, pairs=pairs)\n",
    "pool2 = Pool(data=train_file, column_description=description_file, pairs=pairs_file, delimiter=',')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reducing problem, step 3.1\n",
    "\n",
    "Since we know that $f(d_{qk})$ is an ensemble of trees, we can solve the minimization task from step 3 more accurately, at the cost of slower training.\n",
    "\n",
    "This method is called __PairLogitPairwise__."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "fit_model('PairLogitPairwise')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reducing problem, step 4\n",
    "\n",
    "The previous loss functions minimize a smooth surrogate of the number of pairs $(d_i, d_j)$ where $l_i > l_j$ but $f(d_i) < f(d_j)$; simply put, the number of incorrectly ordered pairs of documents.\n",
    "\n",
    "Since user attention is high on the first documents and low on the last ones, wrongly swapping the first two documents costs more than wrongly swapping the last two.\n",
    "\n",
    "In steps 3 and 3.1 the user can set a weight for each pair.\n",
    "\n",
    "The __YetiRank__ method takes this effect into account and generates weights for pairs according to their positions ([paper](https://cache-mskstoredata08.cdn.yandex.net/download.yandex.ru/company/to_rank_challenge_with_yetirank.pdf)).\n",
    "\n",
    "$$ - \\sum_{i,j \\in Pairs} \\color{red}{w_{ij}} \\log \\left( \\frac{1}{1 + \\exp\\left(-(f(d_i) - f(d_j))\\right)} \\right) $$"
   ]
  },
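  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The weighted objective differs from PairLogit only by the factor $w_{ij}$. The sketch below shows why position-dependent weights matter: two mistakes with the same score gap cost differently when the weight depends on where the pair sits in the current ranking. The weighting function is a HYPOTHETICAL placeholder based on the DCG discount, not the weighting scheme of YetiRank described in the paper; helper names are hypothetical too.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def weighted_pair_logit_loss(scores, pairs, weights):\n",
    "    # PairLogit term of every (winner, loser) pair multiplied by its weight w_ij\n",
    "    diff = scores[pairs[:, 0]] - scores[pairs[:, 1]]\n",
    "    return np.sum(weights * np.logaddexp(0.0, -diff))\n",
    "\n",
    "def position_weights(scores, pairs):\n",
    "    # HYPOTHETICAL placeholder, not YetiRank's scheme: weight a pair by the larger\n",
    "    # DCG-style discount 1 / log2(position + 1) of its two documents, with positions\n",
    "    # taken from the current ranking by score (a single query for simplicity)\n",
    "    ranking = np.argsort(-scores, kind='stable')\n",
    "    positions = np.empty_like(ranking)\n",
    "    positions[ranking] = np.arange(1, len(scores) + 1)\n",
    "    discount = 1.0 / np.log2(positions + 1)\n",
    "    return np.maximum(discount[pairs[:, 0]], discount[pairs[:, 1]])\n",
    "\n",
    "scores = np.array([0.0, 3.0, 2.0, 1.0])  # current ranking: doc 1, doc 2, doc 3, doc 0\n",
    "top_pair = np.array([[2, 1]])            # doc 2 should beat doc 1 (positions 2 and 1)\n",
    "bottom_pair = np.array([[0, 3]])         # doc 0 should beat doc 3 (positions 4 and 3)\n",
    "\n",
    "# Both mistakes have the same score gap of 1, so unweighted PairLogit treats them equally\n",
    "top_loss = weighted_pair_logit_loss(scores, top_pair, position_weights(scores, top_pair))\n",
    "bottom_loss = weighted_pair_logit_loss(scores, bottom_pair, position_weights(scores, bottom_pair))\n",
    "print(top_loss / bottom_loss)  # 2.0, the mistake at the top costs twice as much here\n",
    "```"
   ]
  },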
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "fit_model('YetiRank')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 4.1\n",
    "\n",
    "As in step 3.1, __YetiRankPairwise__ is slower than __YetiRank__ but gives more accurate results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "fit_model('YetiRankPairwise')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "widget = MetricVisualizer(['RMSE', 'QueryRMSE', 'PairLogit', 'PairLogitPairwise', 'YetiRank', 'YetiRankPairwise'])\n",
    "widget.start()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Simple classification\n",
    "\n",
    "Very fast $\\rightarrow$ very slow; Simple method $\\rightarrow$ complex method; Low quality $\\rightarrow$ high quality.\n",
    "\n",
    "1. RMSE\n",
    "2. QueryRMSE\n",
    "3. PairLogit\n",
    "4. PairLogitPairwise\n",
    "5. YetiRank\n",
    "6. YetiRankPairwise\n",
    "\n",
    "Keep in mind that, apart from this rough ordering, the quality of a particular method may depend on your dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Look at the NDCG metric of YetiRank $-$ the model is underfitted. Let's increase the learning rate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "fit_model('YetiRank', {'train_dir': 'YetiRank-lr-0.3', 'learning_rate': 0.3})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "widget = MetricVisualizer(['YetiRank', 'YetiRank-lr-0.3'])\n",
    "widget.start()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Additional parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__Metric period__\n",
    "\n",
    "How often, in iterations, the metrics are calculated. Calculating them less often can speed up the training process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "fit_model('YetiRank', {'metric_period': 50})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__Task type__\n",
    "\n",
    "You can significantly speed up the training procedure by switching to GPU (a CUDA-compatible GPU must be available)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "fit_model('YetiRank', {'task_type': 'GPU'})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
